package nsalt.func

import chisel3._
import chisel3.util._
import chisel3.util.experimental.BoringUtils

import chisel3.experimental.ChiselEnum
// import chisel3.experimental.suppressEnumCastWarning

import nsalt._
import nsalt.bus._
import nsalt.util._


/** Elaboration-time configuration knobs for the load/store unit. */
trait LoadStoreConst {
  /** true: `LoadStore` registers the effective address in a separate step and issues the
    * access from its EXEC state. false: the address is computed and the access issued
    * directly from IDLE.
    */
  val INDEP_ADDR_CALC: Boolean = false

  // NOTE(review): the two sizes below are not referenced elsewhere in this file.
  val MEM_ORDERING_QUEUE_SIZE: Int = 8
  val STORE_QUEUE_SIZE: Int = 8
}

/** States of the LSU control FSM built by `LoadStoreFSM`.
  *
  * Values are assigned in declaration order, starting from IDLE.
  */
object LoadStoreState extends ChiselEnum {
  // Plain access: IDLE -> EXEC -> IDLE.
  // AMO: AMO_L (load old value) -> AMO_A (apply operation) -> AMO_S (store result).
  // LR / SC: IDLE -> LR or SC -> IDLE.
  // NOTE(review): LOAD is never entered by LoadStoreFSM.
  val IDLE, EXEC, LOAD, LR, SC, AMO_L, AMO_A, AMO_S = Value
}

object LoadStoreFSM {

  /** Builds the LSU control state register and its next-state logic.
    *
    * @param valid     request valid; moves IDLE -> EXEC unless an atomic request overrides it
    * @param amoReq    AMO request; moves IDLE -> AMO_L
    * @param lrReq     LR request; moves IDLE -> LR
    * @param scReq     SC request; moves IDLE -> SC, or stays in IDLE when `scInvalid` is set
    * @param scInvalid the SC cannot succeed, so no store is attempted
    * @param ioFire    LSU output handshake; completes EXEC
    * @param execFire  execution-unit output handshake; advances AMO_L, AMO_S, LR and SC
    * @return the state register; callers may still override it with `:=` after this call
    */
  def apply(
    valid: Bool,
    amoReq: Bool,
    lrReq: Bool,
    scReq: Bool,
    scInvalid: Bool,
    ioFire: Bool,
    execFire: Bool
  ): LoadStoreState.Type = {

    val state = RegInit(LoadStoreState.IDLE)

    switch (state) {
      is (LoadStoreState.IDLE) {
        // Last connection wins: atomic requests take priority over a plain `valid`.
        when (valid)  { state := LoadStoreState.EXEC }
        when (amoReq) { state := LoadStoreState.AMO_L }
        when (lrReq)  { state := LoadStoreState.LR }
        when (scReq)  { state := Mux(scInvalid, LoadStoreState.IDLE, LoadStoreState.SC) }
      }
      is (LoadStoreState.EXEC) {
        when (ioFire) { state := LoadStoreState.IDLE }
      }
      is (LoadStoreState.AMO_L) {
        when (execFire) { state := LoadStoreState.AMO_A }
      }
      is (LoadStoreState.AMO_A) {
        // Unconditional: AMO_A lasts exactly one cycle.
        state := LoadStoreState.AMO_S
      }
      is (LoadStoreState.AMO_S) {
        when (execFire) { state := LoadStoreState.IDLE }
      }
      is (LoadStoreState.LR) {
        when (execFire) { state := LoadStoreState.IDLE }
      }
      is (LoadStoreState.SC) {
        when (execFire) { state := LoadStoreState.IDLE }
      }
    }

    state
  }
}

// LSU (Load/Store Unit). Reference implementation (NutShell UnpipelinedLSU):
// https://github.com/OSCPU/NutShell/blob/fd86beadfc47f52973270ce6109edebd2a30363b/src/main/scala/nutcore/backend/fu/UnpipelinedLSU.scala#L36
// Docs:
// https://oscpu.gitbook.io/nutshell/gong-neng-bu-jian-she-ji-xi-jie/lsu

class LoadStore extends Module with Config with LoadStoreConst {

  val io = IO(new LoadStorePort)

  // Short-hand aliases for the request fields.
  val valid = io.in.valid
  val src1 = io.in.bits.src1
  val src2 = io.in.bits.src2
  val oper = io.in.bits.oper

  /** Drives this LSU's request fields and returns its result bits.
    *
    * NOTE(review): this connects to `io.in`, so it is presumably meant to be called from the
    * instantiating module; no caller is visible here.
    *
    * @param valid         request valid
    * @param src1          first address operand (the access address is `src1 + src2`)
    * @param src2          second address operand
    * @param oper          load/store operation type
    * @param dtlbPageFault caller-side signal that is driven from `io.dtlbPageFault`
    * @return `io.out.bits`
    */
  def access(valid: Bool, src1: UInt, src2: UInt, oper: UInt, dtlbPageFault: Bool): UInt = {
    this.valid := valid
    this.src1 := src1
    this.src2 := src2
    this.oper := oper
    dtlbPageFault := io.dtlbPageFault
    io.out.bits
  }

  val loadStoreExec = Module(new LoadStoreExec)
  loadStoreExec.io.instr := DontCare
  io.dtlbPageFault := loadStoreExec.io.dtlbPageFault

  // Request classification; every flag is qualified by `valid`.
  val storeReq = valid & LoadStoreOperType.isStore(oper)
  val loadReq  = valid & LoadStoreOperType.isLoad(oper)
  val atomReq  = valid & LoadStoreOperType.isAtom(oper)
  val amoReq   = valid & LoadStoreOperType.isAMO(oper)
  val lrReq    = valid & LoadStoreOperType.isLR(oper)
  val scReq    = valid & LoadStoreOperType.isSC(oper)


  // Check out the RISC-V Spec Chapter 7.2: LR/SC Instructions
  // https://riscv.org/wp-content/uploads/2017/05/riscv-spec-v2.2.pdf
  // NOTE(review): aq/rl are decoded but not used anywhere below.
  val aq = io.instr(26)
  val rl = io.instr(25)
  val funct3 = io.instr(14, 12)

  // funct3(0) selects the atomic access width: 0 = word, 1 = doubleword.
  val isAtomicWordSingle = !funct3(0)
  val isAtomicWordDouble = funct3(0)

  // Atom LR/SC Control Bits
  val setLr = Wire(Bool())
  val setLrVal = Wire(Bool())
  val setLrAddr = Wire(UInt(ADDR_BITS.W))

  // LR mark and reserved address is provided by CSR
  // NOTE(review): the BoringUtils sinks below are commented out, so `lr` stays false and
  // `lrAddr` is DontCare; `scInvalid` therefore does not reflect a real reservation yet.
  val lr = WireInit(Bool(), false.B)
  val lrAddr = WireInit(UInt(ADDR_BITS.W), DontCare)

  // BoringUtils.addSource(setLr, "set_lr")
  // BoringUtils.addSource(setLrVal, "set_lr_val")
  // BoringUtils.addSource(setLrAddr, "set_lr_addr")

  // BoringUtils.addSink(lr, "lr")
  // BoringUtils.addSink(lrAddr, "lr_addr")

  // An SC fails (and performs no store) when its address differs from the reserved address.
  val scInvalid = !(src1 === lrAddr) && scReq

  // PF signal from TLB
  // NOTE(review): the BoringUtils sinks are commented out, so these stay false and the fault
  // override below is currently triggered only by misaligned accesses.
  val dtlbFinish = WireInit(false.B)
  val dtlbPageFault = WireInit(false.B)
  val dtlbEnable = WireInit(false.B)

  // BoringUtils.addSink(dtlbFinish, "DTLBFINISH")
  // BoringUtils.addSink(dtlbPageFault, "DTLBPF")
  // BoringUtils.addSink(dtlbEnable, "DTLBENABLE")

  // AMO datapath.
  // atomicMem holds the loaded value (AMO_L), then the ALU result (AMO_A) that AMO_S stores.
  // atomicReg keeps the originally loaded value, which is returned as the AMO result.
  val atomicMem = Reg(UInt(XLEN.W))
  val atomicReg = Reg(UInt(XLEN.W))
  val atomic = Module(new AtomicArithLogic)

  atomic.io.src1 := atomicMem
  atomic.io.src2 := io.dataW
  atomic.io.oper := oper
  atomic.io.isWordOp := isAtomicWordSingle
  
  // StoreQueue
  // TODO: inst fence needs storeQueue to be finished
  // val enableStoreQueue = EnableStoreQueue // StoreQueue is disabled for page fault detection
  // if(enableStoreQueue){
  //   val storeQueue = Module(new Queue(new StoreQueueEntry, 4))
  //   storeQueue.io.enq.valid := state === LoadStoreState.IDLE && storeReq
  //   storeQueue.io.enq.bits.src1 := src1
  //   storeQueue.io.enq.bits.src2 := src2
  //   storeQueue.io.enq.bits.dataW := io.dataW
  //   storeQueue.io.enq.bits.oper := oper
  //   storeQueue.io.deq.ready := loadStoreExec.io.out.fire
  // }
  

  // Default connections; every FSM state below overrides them.
  loadStoreExec.io.in.valid     := false.B
  loadStoreExec.io.out.ready    := DontCare
  loadStoreExec.io.in.bits.src1 := DontCare
  loadStoreExec.io.in.bits.src2 := DontCare
  loadStoreExec.io.in.bits.oper := DontCare
  loadStoreExec.io.dataW        := DontCare
  io.out.valid               := false.B
  io.in.ready                := false.B

  // Without INDEP_ADDR_CALC, the IDLE state below computes the address and issues a plain
  // load/store itself, so the request must stay in IDLE until it completes.
  // Entering EXEC would re-drive the execution unit with `addr`, which is DontCare in that
  // configuration. Only the independent-address-calculation mode goes through EXEC.
  val state = LoadStoreFSM(
    valid     = if (INDEP_ADDR_CALC) io.in.valid else false.B,
    amoReq    = amoReq,
    lrReq     = lrReq,
    scReq     = scReq,
    scInvalid = scInvalid,
    ioFire    = io.out.fire,
    execFire  = loadStoreExec.io.out.fire
  )

  // With INDEP_ADDR_CALC, capture the effective address while in IDLE and hold it through EXEC.
  // RegNext's second argument is a reset value, not an enable, so RegEnable is used here.
  val addr = if (INDEP_ADDR_CALC) {
    RegEnable(src1 + src2, state === LoadStoreState.IDLE)
  } else {
    DontCare
  }

  // Word/doubleword access ops used by the AMO and LR/SC states.
  val atomicLoadOper = Mux(isAtomicWordDouble, LoadStoreOperType.ld, LoadStoreOperType.lw)
  val atomicStoreOper = Mux(isAtomicWordDouble, LoadStoreOperType.sd, LoadStoreOperType.sw)

  switch (state) {
    is(LoadStoreState.IDLE){ // calculate address 
      loadStoreExec.io.in.valid     := false.B
      loadStoreExec.io.out.ready    := DontCare 
      loadStoreExec.io.in.bits.src1 := DontCare
      loadStoreExec.io.in.bits.src2 := DontCare
      loadStoreExec.io.in.bits.oper := DontCare
      loadStoreExec.io.dataW        := DontCare
      // A failed SC completes immediately without touching memory.
      io.in.ready                := false.B || scInvalid
      io.out.valid               := false.B || scInvalid

      if(!INDEP_ADDR_CALC){
        // Issue plain loads/stores directly from IDLE; atomics use their own states.
        loadStoreExec.io.in.valid     := io.in.valid && !atomReq
        loadStoreExec.io.out.ready    := io.out.ready 
        loadStoreExec.io.in.bits.src1 := src1 + src2
        loadStoreExec.io.in.bits.src2 := DontCare
        loadStoreExec.io.in.bits.oper := oper
        loadStoreExec.io.dataW        := io.dataW
        io.in.ready                := loadStoreExec.io.out.fire || scInvalid
        io.out.valid               := loadStoreExec.io.out.valid  || scInvalid
      }
    } 

    is(LoadStoreState.EXEC){
      // Only reached with INDEP_ADDR_CALC: issue using the address latched in IDLE.
      loadStoreExec.io.in.valid     := true.B
      loadStoreExec.io.out.ready    := io.out.ready 
      loadStoreExec.io.in.bits.src1 := addr
      loadStoreExec.io.in.bits.src2 := DontCare
      loadStoreExec.io.in.bits.oper := oper
      loadStoreExec.io.dataW        := io.dataW
      io.in.ready                := loadStoreExec.io.out.fire 
      io.out.valid               := loadStoreExec.io.out.valid  
    }

    // is(s_load){
    //   loadStoreExec.io.in.valid     := true.B
    //   loadStoreExec.io.out.ready    := io.out.ready 
    //   loadStoreExec.io.in.bits.src1 := src1
    //   loadStoreExec.io.in.bits.src2 := src2
    //   loadStoreExec.io.in.bits.oper := oper
    //   loadStoreExec.io.dataW        := DontCare
    //   io.in.ready                := loadStoreExec.io.out.fire
    //   io.out.valid               := loadStoreExec.io.out.valid
    //   when(loadStoreExec.io.out.fire){state := LoadStoreState.IDLE}//load finished
    // }

    is(LoadStoreState.AMO_L){
      // Load the old memory value; the result is consumed internally, not returned yet.
      loadStoreExec.io.in.valid     := true.B
      loadStoreExec.io.out.ready    := true.B 
      loadStoreExec.io.in.bits.src1 := src1
      loadStoreExec.io.in.bits.src2 := DontCare
      loadStoreExec.io.in.bits.oper := atomicLoadOper
      loadStoreExec.io.dataW        := DontCare
      io.in.ready                := false.B
      io.out.valid               := false.B

      atomicMem := loadStoreExec.io.out.bits
      atomicReg := loadStoreExec.io.out.bits
    }

    is(LoadStoreState.AMO_A){
      // Apply the AMO operation to the loaded value.
      loadStoreExec.io.in.valid     := false.B
      loadStoreExec.io.out.ready    := false.B 
      loadStoreExec.io.in.bits.src1 := DontCare
      loadStoreExec.io.in.bits.src2 := DontCare
      loadStoreExec.io.in.bits.oper := DontCare
      loadStoreExec.io.dataW        := DontCare
      io.in.ready                := false.B
      io.out.valid               := false.B

      atomicMem := atomic.io.result
    }

    is(LoadStoreState.AMO_S){
      // Store the computed value back to memory.
      // NOTE(review): io.out.valid here (and in LR/SC) depends on io.out.ready through
      // `fire`; confirm consumers tolerate valid-depends-on-ready.
      loadStoreExec.io.in.valid     := true.B
      loadStoreExec.io.out.ready    := io.out.ready
      loadStoreExec.io.in.bits.src1 := src1
      loadStoreExec.io.in.bits.src2 := DontCare
      loadStoreExec.io.in.bits.oper := atomicStoreOper
      loadStoreExec.io.dataW        := atomicMem
      io.in.ready                := loadStoreExec.io.out.fire
      io.out.valid               := loadStoreExec.io.out.fire
    }

    is(LoadStoreState.LR){
      loadStoreExec.io.in.valid     := true.B
      loadStoreExec.io.out.ready    := io.out.ready
      loadStoreExec.io.in.bits.src1 := src1
      loadStoreExec.io.in.bits.src2 := DontCare
      loadStoreExec.io.in.bits.oper := atomicLoadOper
      loadStoreExec.io.dataW        := DontCare
      io.in.ready                := loadStoreExec.io.out.fire
      io.out.valid               := loadStoreExec.io.out.fire
    }

    is(LoadStoreState.SC){
      loadStoreExec.io.in.valid     := true.B
      loadStoreExec.io.out.ready    := io.out.ready
      loadStoreExec.io.in.bits.src1 := src1
      loadStoreExec.io.in.bits.src2 := DontCare
      loadStoreExec.io.in.bits.oper := atomicStoreOper
      loadStoreExec.io.dataW        := io.dataW
      io.in.ready                := loadStoreExec.io.out.fire
      io.out.valid               := loadStoreExec.io.out.fire
    }
  }

  // Exceptions abort the access: return to IDLE and complete the request immediately.
  // This is placed after the switch so these connections take priority.
  when(dtlbPageFault || io.loadAddrMisaligned || io.storeAddrMisaligned){
    state := LoadStoreState.IDLE
    io.out.valid := true.B
    io.in.ready := true.B
  }

  // controled by FSM 
  // io.in.ready := loadStoreExec.io.in.ready
  // loadStoreExec.io.dataW := io.dataW
  // io.out.valid := loadStoreExec.io.out.valid 

  //Set LR/SC bits
  setLr := io.out.fire && (lrReq || scReq)
  setLrVal := lrReq
  setLrAddr := src1

  io.dmem <> loadStoreExec.io.dmem
  // SC returns its failure flag; an AMO returns the originally loaded value; else the exec result.
  io.out.bits := Mux(scReq, 
    scInvalid, 
    Mux(state === LoadStoreState.AMO_S,
      atomicReg,
      loadStoreExec.io.out.bits
    )
  )

  // Sticky MMIO flag: set once lsuMMIO is seen, cleared when a response is presented.
  val lsuMMIO = WireInit(false.B)
  // BoringUtils.addSink(lsuMMIO, "lsuMMIO")

  val mmioReg = RegInit(false.B)
  when (!mmioReg) {
    mmioReg := lsuMMIO
  }
  when (io.out.valid) {
    mmioReg := false.B
  }

  io.isMMIO := mmioReg && io.out.valid

  io.loadAddrMisaligned  := loadStoreExec.io.loadAddrMisaligned
  io.storeAddrMisaligned := loadStoreExec.io.storeAddrMisaligned
}

